import torch
import torchvision.transforms.functional as TF
from utils.loss import remd_loss, remd_loss_m1, remd_loss_m2, remd_loss_m2_mse, remd_loss_m2_max, \
                  remd_loss_m2_meanmax, remd_loss_m2_thresh, remd_loss_m2_focal, remd_loss_mse
import einops
import math
from utils.gaussian_smoothing import GaussianSmoothing

import torch.nn as nn
import torch.nn.functional as F

 
# Module-level filters shared by the loss helpers below.  They are moved to the
# GPU when one is available (the original behaviour); on CPU-only machines
# they stay on the CPU so that importing this module does not raise.
_USE_CUDA = torch.cuda.is_available()

# Smoothing module (sigma=3.0) applied by the `smth=True` option of the
# shape losses.
smth_3 = GaussianSmoothing(sigma=3.0)
if _USE_CUDA:
    smth_3 = smth_3.cuda()


# 3x3 Sobel kernels for the horizontal (x) and vertical (y) derivative.
sobel_x = torch.tensor([[1, 0, -1],
                        [2, 0, -2],
                        [1, 0, -1]], dtype=torch.float32)

sobel_y = torch.tensor([[1, 2, 1],
                        [0, 0, 0],
                        [-1, -2, -1]], dtype=torch.float32)

if _USE_CUDA:
    sobel_x = sobel_x.cuda()
    sobel_y = sobel_y.cuda()

# Conv2d weights are laid out as (out_channels, in_channels, kH, kW).
sobel_x = sobel_x.view(1, 1, 3, 3)
sobel_y = sobel_y.view(1, 1, 3, 3)

# Single-channel, bias-free convolutions; padding=1 keeps the spatial size.
sobel_conv_x = nn.Conv2d(1, 1, kernel_size=3, stride=1, padding=1, bias=False)
sobel_conv_y = nn.Conv2d(1, 1, kernel_size=3, stride=1, padding=1, bias=False)


# Replace the randomly initialised weights with the fixed Sobel kernels.
# NOTE: nn.Parameter defaults to requires_grad=True, so these weights still
# take part in autograd exactly as before.
sobel_conv_x.weight = nn.Parameter(sobel_x)
sobel_conv_y.weight = nn.Parameter(sobel_y)

def calculate_sobel(attn):
    """Return the Sobel gradient magnitude of a single attention map.

    The map is lifted to (1, 1, H, W) by two leading unsqueezes, passed
    through `SelfGuidanceEdits._attn_diff_norm` and detached, filtered with
    the module-level `sobel_conv_x` / `sobel_conv_y`, and the outermost
    one-pixel border of each response is cropped before the two are combined.

    Args:
        attn: attention map; callers in this file pass a 2-D (H, W) tensor.

    Returns:
        sqrt(gx**2 + gy**2) of the cropped responses; for a 2-D input with
        H, W > 1 this is a 2-D (H - 2, W - 2) tensor.

    NOTE(review): because the normalised map is detached, no gradient flows
    back to `attn` through this function — confirm that this is intended.
    """
    normalized = SelfGuidanceEdits._attn_diff_norm(attn[None, None]).detach()

    # padding=1 keeps the spatial size, so the crop drops the border pixels
    # whose responses depend on the zero padding.
    grad_x = sobel_conv_x(normalized).squeeze()[1:-1, 1:-1]
    grad_y = sobel_conv_y(normalized).squeeze()[1:-1, 1:-1]
    return torch.sqrt(grad_y ** 2 + grad_x ** 2)

def Fourier_filter(x, threshold, scale):
    """Scale the central (low-frequency) band of the 2-D spectrum of `x`.

    The last two axes of `x` are transformed with an FFT and shifted so the
    zero frequency sits at (H // 2, W // 2).  A square window of half-width
    `threshold` around that centre is multiplied by `scale`; every other
    frequency is left unchanged.  The result is transformed back and its real
    part is returned.

    Args:
        x: 4-D tensor (B, C, H, W).
        threshold: half-width, in frequency bins, of the scaled window.
        scale: multiplier applied inside the window.

    Returns:
        Real tensor with the same shape as `x`.
    """
    # FFT, with the zero frequency moved to the centre of the spectrum.
    x_freq = torch.fft.fftn(x, dim=(-2, -1))
    x_freq = torch.fft.fftshift(x_freq, dim=(-2, -1))

    B, C, H, W = x_freq.shape
    # Build the mask on the input's device instead of hard-coding CUDA, so
    # CPU tensors and tensors on a non-default GPU work as well.
    mask = torch.ones((B, C, H, W), device=x_freq.device)

    crow, ccol = H // 2, W // 2
    # Clamp the window start at 0: a negative start would be interpreted as a
    # from-the-end index and select the wrong (or an empty) band.
    top = max(crow - threshold, 0)
    left = max(ccol - threshold, 0)
    mask[..., top:crow + threshold, left:ccol + threshold] = scale
    x_freq = x_freq * mask

    # Inverse FFT.
    x_freq = torch.fft.ifftshift(x_freq, dim=(-2, -1))
    x_filtered = torch.fft.ifftn(x_freq, dim=(-2, -1)).real

    return x_filtered

class SelfGuidanceEdits:
  @staticmethod
  def _centroid(a):
    x = torch.linspace(0, 1, a.shape[-2]).to(a.device)
    y = torch.linspace(0, 1, a.shape[-3]).to(a.device)
    # a is (n, h, w, k)
    attn_x = a.sum(-3)  # (n, w, k)
    attn_y = a.sum(-2)  # (n, h, k)

    def f(_attn, _linspace):
      _attn = _attn / (_attn.sum(-2, keepdim=True) + 1e-4)  # (n, 1, k)
      _weighted_attn = (
          _linspace[None, ..., None] * _attn
      )  # (n, h or w, k)
      return _weighted_attn.sum(-2)  # (n, k)

    centroid_x = f(attn_x, x)
    centroid_y = f(attn_y, y)
    centroid = torch.stack((centroid_x, centroid_y), -1)  # (n, k, 2)
    return centroid
  
  @staticmethod
  def _attn_softmax(report_attn, hard=False, thresh=0.5):
    attn_min = report_attn.min()
    attn_max = report_attn.max()
    attn_softmax = torch.exp(report_attn - attn_max) 
    attn_softmax = attn_softmax / torch.sum(attn_softmax)
    if hard:
      return (attn_softmax>thresh)*1.0
    return attn_softmax

  @staticmethod
  def _attn_diff_norm(report_attn, hard=False, thresh=0.5):
    # attn_min = report_attn.min(2,keepdim=True)[0].min(3,keepdim=True)[0]
    # attn_max = report_attn.max(2,keepdim=True)[0].max(3,keepdim=True)[0]
    attn_min = report_attn.min()
    attn_max = report_attn.max()
    attn_thresh = (report_attn - attn_min) / (attn_max - attn_min + 1e-4)
    if hard:
      return (attn_thresh>thresh)*1.0
    attn_binarized = torch.sigmoid((attn_thresh-thresh)*10)
    # attn_min = attn_binarized.min(2,keepdim=True)[0].min(3,keepdim=True)[0]
    # attn_max = attn_binarized.max(2,keepdim=True)[0].max(3,keepdim=True)[0]
    attn_min = attn_binarized.min()
    attn_max = attn_binarized.max()
    attn_norm = (attn_binarized - attn_min) / (attn_max - attn_min + 1e-4)
    return attn_norm

  @staticmethod
  def _attn_diff_norm_soft(report_attn, hard=False, thresh=0.5):
    # attn_min = report_attn.min(2,keepdim=True)[0].min(3,keepdim=True)[0]
    # attn_max = report_attn.max(2,keepdim=True)[0].max(3,keepdim=True)[0]
    base_thresh = 0.1
    attn_min = report_attn.min()
    attn_max = report_attn.max()
    attn_thresh = (report_attn - attn_min) / (attn_max - attn_min + 1e-4)
    attn_thresh_flatten = attn_thresh.flatten()
    k = int(attn_thresh_flatten[attn_thresh_flatten > base_thresh].shape[0] * (1 - thresh))
    kth_value, _ = torch.kthvalue(attn_thresh_flatten[attn_thresh_flatten > base_thresh], k)
    if hard:
      return (attn_thresh>kth_value)*1.0
    attn_binarized = torch.sigmoid((attn_thresh-kth_value)*10)
    # attn_min = attn_binarized.min(2,keepdim=True)[0].min(3,keepdim=True)[0]
    # attn_max = attn_binarized.max(2,keepdim=True)[0].max(3,keepdim=True)[0]
    attn_min = attn_binarized.min()
    attn_max = attn_binarized.max()
    attn_norm = (attn_binarized - attn_min) / (attn_max - attn_min + 1e-4)
    return attn_norm

  # TODO
  @staticmethod
  def mean_app(aux,i, tgt, idxs=None, L2=False, thresh=0.5):
    """Distance between the mask-averaged features of observation and target.

    Args:
      aux: nested mapping; its first two entries are read as the attention
        group and the feature group, and the first value of each group is
        indexed with `i`.
      i: index into each group's first value.
      tgt: same structure as `aux`; its tensors are detached and moved to the
        device of the last tensor leaf of `aux`.
      idxs: optional index applied to the last axis of both attention maps.
      L2: return the mean of 0.5 * diff**2 instead of the mean absolute diff.
      thresh: threshold for the hard attention mask.

    Returns:
      Scalar tensor.
    """
    dev = torch.utils._pytree.tree_flatten(aux)[0][-1].device
    keys = ['_attn', '_feats']
    obs = {key: next(iter(group.values()))[i] for key, group in zip(keys, aux.values())}
    ref = {key: next(iter(group.values()))[i].detach().to(dev) for key, group in zip(keys, tgt.values())}
    if idxs is not None:
      obs['_attn'] = obs['_attn'][..., idxs]
      ref['_attn'] = ref['_attn'][..., idxs]

    def _mean_appearance(_feats, _attn):
      # assumes _feats is (n, c, h, w) and _attn is (n, h, w, k), matching the
      # permutes below — TODO confirm against the caller.
      feats_nhwc = _feats.permute(0, 2, 3, 1)
      resized = TF.resize(_attn.detach().permute(0, 3, 1, 2), feats_nhwc.shape[1], antialias=True)
      mask = SelfGuidanceEdits._attn_diff_norm(resized.permute(0, 2, 3, 1), hard=True, thresh=thresh)

      # Broadcast to (n, h, w, k, c), sum over h and w, and divide by the
      # mask area (1e-4 guards against an empty mask) -> (n, k, c).
      weighted_sum = (mask[..., None] * feats_nhwc[..., None, :]).sum((-3, -4))
      mask_area = mask.sum((-2, -3))[..., None]
      return weighted_sum / (1e-4 + mask_area)

    obs_app = _mean_appearance(**obs)
    tgt_app = _mean_appearance(**ref)
    delta = obs_app - tgt_app
    if L2:
      return (0.5 * delta ** 2).mean()
    return delta.abs().mean()
  
  # TODO
  @staticmethod
  def match(aux,i, tgt, ref_mask, base_mask, map, idxs=None, L2=False, thresh=0.5):
    tgt_aux = tgt
    dev = torch.utils._pytree.tree_flatten(aux)[0][-1].device
    aux = {key:next(iter(v.values()))[i] for (k,v), key in zip(aux.items(), ['_attn', '_feats'])}
    tgt_aux = {key:next(iter(v.values()))[i].detach().to(dev) for (k,v), key in zip(tgt_aux.items(), ['_attn', '_feats'])}
    if idxs is not None:
      aux['_attn'] = aux['_attn'][..., idxs]
      tgt_aux['_attn'] = tgt_aux['_attn'][..., idxs]
    def _compute(_feats, _attn):
      _attn = _attn.detach()
      _feats = _feats.permute(0,2,3,1)
      
      attn = _attn.reshape(_attn.shape[0], -1, _attn.shape[-1])
      feats = _feats.reshape(_feats.shape[0], -1, _feats.shape[-1])
      feats = _feats.reshape(_feats.shape[0], -1, _feats.shape[-1])

      return feats[(attn>0).squeeze(-1)]
    obs_feats = _compute(aux['_feats'], base_mask).squeeze(0)
    tgt_feats = _compute(tgt_aux['_feats'], ref_mask).squeeze(0)

    feats1, feats2 = (obs_feats, tgt_feats[map]) 
    
    # bg_loss = 
    loss = ((feats1 - feats2)**2).mean()
    return loss


  # TODO
  @staticmethod
  def remd(aux,i, tgt, idxs=None, L2=False, thresh=0.5):
    """`remd_loss` between attention-masked observation and target features.

    Args:
      aux: nested mapping; its first two entries are read as the attention
        group and the feature group, and the first value of each group is
        indexed with `i`.
      i: index into each group's first value.
      tgt: same structure as `aux`; its tensors are detached and moved to the
        device of the last tensor leaf of `aux`.
      idxs: optional index applied to the last axis of both attention maps.
      L2: unused.
      thresh: threshold for the hard attention mask.

    Returns:
      First element of the `remd_loss` result.
    """
    dev = torch.utils._pytree.tree_flatten(aux)[0][-1].device
    keys = ['_attn', '_feats']
    obs = {key: next(iter(group.values()))[i] for key, group in zip(keys, aux.values())}
    ref = {key: next(iter(group.values()))[i].detach().to(dev) for key, group in zip(keys, tgt.values())}
    if idxs is not None:
      obs['_attn'] = obs['_attn'][..., idxs]
      ref['_attn'] = ref['_attn'][..., idxs]

    def _masked_feats(_feats, _attn):
      # assumes _feats is (n, c, h, w) and _attn is (n, h, w, k) — TODO confirm.
      feats_nhwc = _feats.permute(0, 2, 3, 1)
      # Resize the detached attention to the feature resolution, then binarise.
      resized = TF.resize(_attn.detach().permute(0, 3, 1, 2), feats_nhwc.shape[1], antialias=True)
      mask = SelfGuidanceEdits._attn_diff_norm(resized.permute(0, 2, 3, 1), hard=True, thresh=thresh)
      # Flatten the spatial axes and keep the feature rows inside the mask.
      keep = (mask.flatten(1, 2) > 0).squeeze(-1)
      return feats_nhwc.flatten(1, 2)[keep]

    obs_feats = _masked_feats(**obs)
    tgt_feats = _masked_feats(**ref)

    return remd_loss(tgt_feats, obs_feats, h=None, cos_d=True,
                     splits=[obs_feats.shape[-1]], return_mat=False)[0]

  # TODO
  @staticmethod
  def remd_mse(aux,i, tgt, idxs=None, L2=False, thresh=0.5):
    """`remd_loss_mse` between attention-masked observation and target features.

    Args:
      aux: nested mapping; its first two entries are read as the attention
        group and the feature group, and the first value of each group is
        indexed with `i`.
      i: index into each group's first value.
      tgt: same structure as `aux`; its tensors are detached and moved to the
        device of the last tensor leaf of `aux`.
      idxs: optional index applied to the last axis of both attention maps.
      L2: unused.
      thresh: threshold for the hard attention mask.

    Returns:
      First element of the `remd_loss_mse` result.
    """
    dev = torch.utils._pytree.tree_flatten(aux)[0][-1].device
    keys = ['_attn', '_feats']
    obs = {key: next(iter(group.values()))[i] for key, group in zip(keys, aux.values())}
    ref = {key: next(iter(group.values()))[i].detach().to(dev) for key, group in zip(keys, tgt.values())}
    if idxs is not None:
      obs['_attn'] = obs['_attn'][..., idxs]
      ref['_attn'] = ref['_attn'][..., idxs]

    def _masked_feats(_feats, _attn):
      # assumes _feats is (n, c, h, w) and _attn is (n, h, w, k) — TODO confirm.
      feats_nhwc = _feats.permute(0, 2, 3, 1)
      # Resize the detached attention to the feature resolution, then binarise.
      resized = TF.resize(_attn.detach().permute(0, 3, 1, 2), feats_nhwc.shape[1], antialias=True)
      mask = SelfGuidanceEdits._attn_diff_norm(resized.permute(0, 2, 3, 1), hard=True, thresh=thresh)
      # Flatten the spatial axes and keep the feature rows inside the mask.
      keep = (mask.flatten(1, 2) > 0).squeeze(-1)
      return feats_nhwc.flatten(1, 2)[keep]

    obs_feats = _masked_feats(**obs)
    tgt_feats = _masked_feats(**ref)

    return remd_loss_mse(tgt_feats, obs_feats, h=None, cos_d=True,
                         splits=[obs_feats.shape[-1]], return_mat=False)[0]
  
  # TODO
  @staticmethod
  def remd2_mask(aux,i, tgt, idxs=None, L2=False, thresh=0.5):
    """`remd_loss_m2` on features inside both the attention mask and a global mask.

    The extra masks are read from the module globals `mask_base` (for the
    observation) and `mask_ref` (for the target); they are not defined in
    this block and must be set elsewhere before the call.

    Args:
      aux: nested mapping; its first two entries are read as the attention
        group and the feature group, and the first value of each group is
        indexed with `i`.
      i: index into each group's first value.
      tgt: same structure as `aux`; its tensors are detached and moved to the
        device of the last tensor leaf of `aux`.
      idxs: optional index applied to the last axis of both attention maps.
      L2: unused.
      thresh: threshold for both hard masks.

    Returns:
      First element of the `remd_loss_m2` result.
    """
    global mask_base, mask_ref

    dev = torch.utils._pytree.tree_flatten(aux)[0][-1].device
    keys = ['_attn', '_feats']
    obs = {key: next(iter(group.values()))[i] for key, group in zip(keys, aux.values())}
    ref = {key: next(iter(group.values()))[i].detach().to(dev) for key, group in zip(keys, tgt.values())}
    if idxs is not None:
      obs['_attn'] = obs['_attn'][..., idxs]
      ref['_attn'] = ref['_attn'][..., idxs]

    def _masked_feats(_feats, _attn, _mask):
      # assumes _feats is (n, c, h, w) and _attn is (n, h, w, k) — TODO confirm.
      feats_nhwc = _feats.permute(0, 2, 3, 1)
      side = feats_nhwc.shape[1]

      # Attention: resize to the feature resolution, then binarise.
      attn_resized = TF.resize(_attn.detach().permute(0, 3, 1, 2), side, antialias=True)
      attn_hard = SelfGuidanceEdits._attn_diff_norm(attn_resized.permute(0, 2, 3, 1), hard=True, thresh=thresh)

      # External mask: already channels-first; resize, binarise, then go
      # channels-last.
      mask_resized = TF.resize(_mask, side, antialias=True)
      mask_hard = SelfGuidanceEdits._attn_diff_norm(mask_resized, hard=True, thresh=thresh).permute(0, 2, 3, 1)

      # Keep the feature rows that lie inside both masks.
      region = (attn_hard.flatten(1, 2) > 0) & (mask_hard.flatten(1, 2) > 0)
      return feats_nhwc.flatten(1, 2)[region.squeeze(-1)]

    # Two leading unsqueezes lift each global mask to (1, 1, ...).
    obs['_mask'] = mask_base.unsqueeze(0).unsqueeze(0).to(dev)
    ref['_mask'] = mask_ref.unsqueeze(0).unsqueeze(0).to(dev)
    obs_feats = _masked_feats(**obs)
    tgt_feats = _masked_feats(**ref)

    return remd_loss_m2(tgt_feats, obs_feats, h=None, cos_d=True,
                        splits=[obs_feats.shape[-1]], return_mat=False)[0]

  # TODO
  @staticmethod
  def remd2(aux,i, tgt, idxs=None, L2=False, thresh=0.5):
    """`remd_loss_m2` between attention-masked observation and target features.

    Args:
      aux: nested mapping; its first two entries are read as the attention
        group and the feature group, and the first value of each group is
        indexed with `i`.
      i: index into each group's first value.
      tgt: same structure as `aux`; its tensors are detached and moved to the
        device of the last tensor leaf of `aux`.
      idxs: optional index applied to the last axis of both attention maps.
      L2: unused.
      thresh: threshold for the hard attention mask.

    Returns:
      First element of the `remd_loss_m2` result.
    """
    dev = torch.utils._pytree.tree_flatten(aux)[0][-1].device
    keys = ['_attn', '_feats']
    obs = {key: next(iter(group.values()))[i] for key, group in zip(keys, aux.values())}
    ref = {key: next(iter(group.values()))[i].detach().to(dev) for key, group in zip(keys, tgt.values())}
    if idxs is not None:
      obs['_attn'] = obs['_attn'][..., idxs]
      ref['_attn'] = ref['_attn'][..., idxs]

    def _masked_feats(_feats, _attn):
      # assumes _feats is (n, c, h, w) and _attn is (n, h, w, k) — TODO confirm.
      feats_nhwc = _feats.permute(0, 2, 3, 1)
      # Resize the detached attention to the feature resolution, then binarise.
      resized = TF.resize(_attn.detach().permute(0, 3, 1, 2), feats_nhwc.shape[1], antialias=True)
      mask = SelfGuidanceEdits._attn_diff_norm(resized.permute(0, 2, 3, 1), hard=True, thresh=thresh)
      # Flatten the spatial axes and keep the feature rows inside the mask.
      keep = (mask.flatten(1, 2) > 0).squeeze(-1)
      return feats_nhwc.flatten(1, 2)[keep]

    obs_feats = _masked_feats(**obs)
    tgt_feats = _masked_feats(**ref)

    return remd_loss_m2(tgt_feats, obs_feats, h=None, cos_d=True,
                        splits=[obs_feats.shape[-1]], return_mat=False)[0]

  # TODO
  @staticmethod
  def remd2_thresh(aux,i, tgt, idxs=None, L2=False, thresh=0.5):
    """`remd_loss_m2_thresh` between attention-masked observation and target features.

    Args:
      aux: nested mapping; its first two entries are read as the attention
        group and the feature group, and the first value of each group is
        indexed with `i`.
      i: index into each group's first value.
      tgt: same structure as `aux`; its tensors are detached and moved to the
        device of the last tensor leaf of `aux`.
      idxs: optional index applied to the last axis of both attention maps.
      L2: unused.
      thresh: threshold for the hard attention mask.

    Returns:
      First element of the `remd_loss_m2_thresh` result.
    """
    dev = torch.utils._pytree.tree_flatten(aux)[0][-1].device
    keys = ['_attn', '_feats']
    obs = {key: next(iter(group.values()))[i] for key, group in zip(keys, aux.values())}
    ref = {key: next(iter(group.values()))[i].detach().to(dev) for key, group in zip(keys, tgt.values())}
    if idxs is not None:
      obs['_attn'] = obs['_attn'][..., idxs]
      ref['_attn'] = ref['_attn'][..., idxs]

    def _masked_feats(_feats, _attn):
      # assumes _feats is (n, c, h, w) and _attn is (n, h, w, k) — TODO confirm.
      feats_nhwc = _feats.permute(0, 2, 3, 1)
      # Resize the detached attention to the feature resolution, then binarise.
      resized = TF.resize(_attn.detach().permute(0, 3, 1, 2), feats_nhwc.shape[1], antialias=True)
      mask = SelfGuidanceEdits._attn_diff_norm(resized.permute(0, 2, 3, 1), hard=True, thresh=thresh)
      # Flatten the spatial axes and keep the feature rows inside the mask.
      keep = (mask.flatten(1, 2) > 0).squeeze(-1)
      return feats_nhwc.flatten(1, 2)[keep]

    obs_feats = _masked_feats(**obs)
    tgt_feats = _masked_feats(**ref)

    return remd_loss_m2_thresh(tgt_feats, obs_feats, h=None, cos_d=True,
                               splits=[obs_feats.shape[-1]], return_mat=False)[0]

  # TODO
  @staticmethod
  def remd2_focal(aux,i, tgt, idxs=None, L2=False, thresh=0.5, gamma=3):
    """`remd_loss_m2_focal` between attention-masked observation and target features.

    Args:
      aux: nested mapping; its first two entries are read as the attention
        group and the feature group, and the first value of each group is
        indexed with `i`.
      i: index into each group's first value.
      tgt: same structure as `aux`; its tensors are detached and moved to the
        device of the last tensor leaf of `aux`.
      idxs: optional index applied to the last axis of both attention maps.
      L2: unused.
      thresh: threshold for the hard attention mask.
      gamma: forwarded unchanged to `remd_loss_m2_focal`.

    Returns:
      First element of the `remd_loss_m2_focal` result.
    """
    dev = torch.utils._pytree.tree_flatten(aux)[0][-1].device
    keys = ['_attn', '_feats']
    obs = {key: next(iter(group.values()))[i] for key, group in zip(keys, aux.values())}
    ref = {key: next(iter(group.values()))[i].detach().to(dev) for key, group in zip(keys, tgt.values())}
    if idxs is not None:
      obs['_attn'] = obs['_attn'][..., idxs]
      ref['_attn'] = ref['_attn'][..., idxs]

    def _masked_feats(_feats, _attn):
      # assumes _feats is (n, c, h, w) and _attn is (n, h, w, k) — TODO confirm.
      feats_nhwc = _feats.permute(0, 2, 3, 1)
      # Resize the detached attention to the feature resolution, then binarise.
      resized = TF.resize(_attn.detach().permute(0, 3, 1, 2), feats_nhwc.shape[1], antialias=True)
      mask = SelfGuidanceEdits._attn_diff_norm(resized.permute(0, 2, 3, 1), hard=True, thresh=thresh)
      # Flatten the spatial axes and keep the feature rows inside the mask.
      keep = (mask.flatten(1, 2) > 0).squeeze(-1)
      return feats_nhwc.flatten(1, 2)[keep]

    obs_feats = _masked_feats(**obs)
    tgt_feats = _masked_feats(**ref)

    return remd_loss_m2_focal(tgt_feats, obs_feats, h=None, cos_d=True,
                              splits=[obs_feats.shape[-1]], return_mat=False, gamma=gamma)[0]

  # TODO
  @staticmethod
  def none(aux,i, tgt, idxs=None, L2=False, thresh=0.5):
    dev = torch.utils._pytree.tree_flatten(aux)[0][-1].device
    return torch.tensor(0.).to(dev)
  

  # TODO
  @staticmethod
  def remd2_mse(aux,i, tgt, idxs=None, L2=False, thresh=0.5):
    """`remd_loss_m2_mse` between attention-masked observation and target features.

    Args:
      aux: nested mapping; its first two entries are read as the attention
        group and the feature group, and the first value of each group is
        indexed with `i`.
      i: index into each group's first value.
      tgt: same structure as `aux`; its tensors are detached and moved to the
        device of the last tensor leaf of `aux`.
      idxs: optional index applied to the last axis of both attention maps.
      L2: unused.
      thresh: threshold for the hard attention mask.

    Returns:
      First element of the `remd_loss_m2_mse` result.
    """
    dev = torch.utils._pytree.tree_flatten(aux)[0][-1].device
    keys = ['_attn', '_feats']
    obs = {key: next(iter(group.values()))[i] for key, group in zip(keys, aux.values())}
    ref = {key: next(iter(group.values()))[i].detach().to(dev) for key, group in zip(keys, tgt.values())}
    if idxs is not None:
      obs['_attn'] = obs['_attn'][..., idxs]
      ref['_attn'] = ref['_attn'][..., idxs]

    def _masked_feats(_feats, _attn):
      # assumes _feats is (n, c, h, w) and _attn is (n, h, w, k) — TODO confirm.
      feats_nhwc = _feats.permute(0, 2, 3, 1)
      # Resize the detached attention to the feature resolution, then binarise.
      resized = TF.resize(_attn.detach().permute(0, 3, 1, 2), feats_nhwc.shape[1], antialias=True)
      mask = SelfGuidanceEdits._attn_diff_norm(resized.permute(0, 2, 3, 1), hard=True, thresh=thresh)
      # Flatten the spatial axes and keep the feature rows inside the mask.
      keep = (mask.flatten(1, 2) > 0).squeeze(-1)
      return feats_nhwc.flatten(1, 2)[keep]

    obs_feats = _masked_feats(**obs)
    tgt_feats = _masked_feats(**ref)

    return remd_loss_m2_mse(tgt_feats, obs_feats, h=None, cos_d=True,
                            splits=[obs_feats.shape[-1]], return_mat=False)[0]

  # TODO
  @staticmethod
  def remd2_mse_high_freq(aux,i, tgt, idxs=None, L2=False, thresh=0.5):
    """`remd_loss_m2_mse` on frequency-filtered, attention-masked features.

    Unlike most sibling losses, the attention map is binarised first and
    resized afterwards, and the flattened features are filtered with a 1-D
    FFT along the spatial axis before masking.

    Args:
      aux: nested mapping; its first two entries are read as the attention
        group and the feature group, and the first value of each group is
        indexed with `i`.
      i: index into each group's first value.
      tgt: same structure as `aux`; its tensors are detached and moved to the
        device of the last tensor leaf of `aux`.
      idxs: optional index applied to the last axis of both attention maps.
      L2: unused.
      thresh: threshold for the hard attention mask.

    Returns:
      First element of the `remd_loss_m2_mse` result.
    """
    dev = torch.utils._pytree.tree_flatten(aux)[0][-1].device
    keys = ['_attn', '_feats']
    obs = {key: next(iter(group.values()))[i] for key, group in zip(keys, aux.values())}
    ref = {key: next(iter(group.values()))[i].detach().to(dev) for key, group in zip(keys, tgt.values())}
    if idxs is not None:
      obs['_attn'] = obs['_attn'][..., idxs]
      ref['_attn'] = ref['_attn'][..., idxs]

    def _filtered_feats(_feats, _attn):
      # assumes _feats is (n, c, h, w) and _attn is (n, h, w, k) — TODO confirm.
      feats_nhwc = _feats.permute(0, 2, 3, 1)
      # Binarise at the attention resolution, then resize; the resized mask is
      # only tested for > 0 below.
      hard = SelfGuidanceEdits._attn_diff_norm(_attn.detach(), hard=True, thresh=thresh)
      mask = TF.resize(hard.permute(0, 3, 1, 2), feats_nhwc.shape[1], antialias=True).permute(0, 2, 3, 1)

      flat_mask = mask.flatten(1, 2)
      flat_feats = feats_nhwc.flatten(1, 2)

      spectrum = torch.fft.fft(flat_feats, dim=-2)
      _, length, _ = flat_feats.shape
      # NOTE(review): the first half of the (unshifted) spectrum is overwritten
      # with the constant 0.2 rather than scaled or zeroed — confirm intent.
      spectrum[:, :length//2, :] = 0.2
      filtered = torch.fft.ifftn(spectrum, dim=-2).real

      return filtered[(flat_mask > 0).squeeze(-1)]

    obs_feats = _filtered_feats(**obs)
    tgt_feats = _filtered_feats(**ref)

    return remd_loss_m2_mse(tgt_feats, obs_feats, h=None, cos_d=True,
                            splits=[obs_feats.shape[-1]], return_mat=False)[0]

  # TODO
  @staticmethod
  def remd2_max(aux,i, tgt, idxs=None, L2=False, thresh=0.5):
    """`remd_loss_m2_max` between attention-masked observation and target features.

    Args:
      aux: nested mapping; its first two entries are read as the attention
        group and the feature group, and the first value of each group is
        indexed with `i`.
      i: index into each group's first value.
      tgt: same structure as `aux`; its tensors are detached and moved to the
        device of the last tensor leaf of `aux`.
      idxs: optional index applied to the last axis of both attention maps.
      L2: unused.
      thresh: threshold for the hard attention mask.

    Returns:
      First element of the `remd_loss_m2_max` result.
    """
    dev = torch.utils._pytree.tree_flatten(aux)[0][-1].device
    keys = ['_attn', '_feats']
    obs = {key: next(iter(group.values()))[i] for key, group in zip(keys, aux.values())}
    ref = {key: next(iter(group.values()))[i].detach().to(dev) for key, group in zip(keys, tgt.values())}
    if idxs is not None:
      obs['_attn'] = obs['_attn'][..., idxs]
      ref['_attn'] = ref['_attn'][..., idxs]

    def _masked_feats(_feats, _attn):
      # assumes _feats is (n, c, h, w) and _attn is (n, h, w, k) — TODO confirm.
      feats_nhwc = _feats.permute(0, 2, 3, 1)
      # Resize the detached attention to the feature resolution, then binarise.
      resized = TF.resize(_attn.detach().permute(0, 3, 1, 2), feats_nhwc.shape[1], antialias=True)
      mask = SelfGuidanceEdits._attn_diff_norm(resized.permute(0, 2, 3, 1), hard=True, thresh=thresh)
      # Flatten the spatial axes and keep the feature rows inside the mask.
      keep = (mask.flatten(1, 2) > 0).squeeze(-1)
      return feats_nhwc.flatten(1, 2)[keep]

    obs_feats = _masked_feats(**obs)
    tgt_feats = _masked_feats(**ref)

    return remd_loss_m2_max(tgt_feats, obs_feats, h=None, cos_d=True,
                            splits=[obs_feats.shape[-1]], return_mat=False)[0]

  # TODO
  @staticmethod
  def remd2_meanmax(aux,i, tgt, idxs=None, L2=False, thresh=0.5):
    """`remd_loss_m2_meanmax` between attention-masked observation and target features.

    Args:
      aux: nested mapping; its first two entries are read as the attention
        group and the feature group, and the first value of each group is
        indexed with `i`.
      i: index into each group's first value.
      tgt: same structure as `aux`; its tensors are detached and moved to the
        device of the last tensor leaf of `aux`.
      idxs: optional index applied to the last axis of both attention maps.
      L2: unused.
      thresh: threshold for the hard attention mask.

    Returns:
      First element of the `remd_loss_m2_meanmax` result.
    """
    dev = torch.utils._pytree.tree_flatten(aux)[0][-1].device
    keys = ['_attn', '_feats']
    obs = {key: next(iter(group.values()))[i] for key, group in zip(keys, aux.values())}
    ref = {key: next(iter(group.values()))[i].detach().to(dev) for key, group in zip(keys, tgt.values())}
    if idxs is not None:
      obs['_attn'] = obs['_attn'][..., idxs]
      ref['_attn'] = ref['_attn'][..., idxs]

    def _masked_feats(_feats, _attn):
      # assumes _feats is (n, c, h, w) and _attn is (n, h, w, k) — TODO confirm.
      feats_nhwc = _feats.permute(0, 2, 3, 1)
      # Resize the detached attention to the feature resolution, then binarise.
      resized = TF.resize(_attn.detach().permute(0, 3, 1, 2), feats_nhwc.shape[1], antialias=True)
      mask = SelfGuidanceEdits._attn_diff_norm(resized.permute(0, 2, 3, 1), hard=True, thresh=thresh)
      # Flatten the spatial axes and keep the feature rows inside the mask.
      keep = (mask.flatten(1, 2) > 0).squeeze(-1)
      return feats_nhwc.flatten(1, 2)[keep]

    obs_feats = _masked_feats(**obs)
    tgt_feats = _masked_feats(**ref)

    return remd_loss_m2_meanmax(tgt_feats, obs_feats, h=None, cos_d=True,
                                splits=[obs_feats.shape[-1]], return_mat=False)[0]

  # TODO
  @staticmethod
  def remd1_high_freq(aux,i, tgt, idxs=None, L2=False, thresh=0.5):
    """`remd_loss_m1` on frequency-filtered, attention-masked features.

    The attention map is binarised first and resized afterwards; the
    flattened features are filtered with a 1-D FFT along the spatial axis,
    with the first half of the spectrum set to zero, before masking.

    Args:
      aux: nested mapping; its first two entries are read as the attention
        group and the feature group, and the first value of each group is
        indexed with `i`.
      i: index into each group's first value.
      tgt: same structure as `aux`; its tensors are detached and moved to the
        device of the last tensor leaf of `aux`.
      idxs: optional index applied to the last axis of both attention maps.
      L2: unused.
      thresh: threshold for the hard attention mask.

    Returns:
      First element of the `remd_loss_m1` result.
    """
    dev = torch.utils._pytree.tree_flatten(aux)[0][-1].device
    keys = ['_attn', '_feats']
    obs = {key: next(iter(group.values()))[i] for key, group in zip(keys, aux.values())}
    ref = {key: next(iter(group.values()))[i].detach().to(dev) for key, group in zip(keys, tgt.values())}
    if idxs is not None:
      obs['_attn'] = obs['_attn'][..., idxs]
      ref['_attn'] = ref['_attn'][..., idxs]

    def _filtered_feats(_feats, _attn):
      # assumes _feats is (n, c, h, w) and _attn is (n, h, w, k) — TODO confirm.
      feats_nhwc = _feats.permute(0, 2, 3, 1)
      # Binarise at the attention resolution, then resize; the resized mask is
      # only tested for > 0 below.
      hard = SelfGuidanceEdits._attn_diff_norm(_attn.detach(), hard=True, thresh=thresh)
      mask = TF.resize(hard.permute(0, 3, 1, 2), feats_nhwc.shape[1], antialias=True).permute(0, 2, 3, 1)

      flat_mask = mask.flatten(1, 2)
      flat_feats = feats_nhwc.flatten(1, 2)

      spectrum = torch.fft.fft(flat_feats, dim=-2)
      _, length, _ = flat_feats.shape
      # NOTE(review): the spectrum is not fft-shifted, so indices below
      # length // 2 are one half of the frequency range, not "the low band" —
      # confirm this is the intended high-pass.
      spectrum[:, :length//2, :] = 0
      filtered = torch.fft.ifftn(spectrum, dim=-2).real

      return filtered[(flat_mask > 0).squeeze(-1)]

    obs_feats = _filtered_feats(**obs)
    tgt_feats = _filtered_feats(**ref)

    return remd_loss_m1(tgt_feats, obs_feats, h=None, cos_d=True,
                        splits=[obs_feats.shape[-1]], return_mat=False)[0]

  # TODO
  @staticmethod
  def remd1_low_freq(aux,i, tgt, idxs=None, L2=False, thresh=0.5):
    """`remd_loss_m1` on frequency-filtered, attention-masked features.

    The attention map is binarised first and resized afterwards; the
    flattened features are filtered with a 1-D FFT along the spatial axis,
    with the second half of the spectrum set to zero, before masking.

    Args:
      aux: nested mapping; its first two entries are read as the attention
        group and the feature group, and the first value of each group is
        indexed with `i`.
      i: index into each group's first value.
      tgt: same structure as `aux`; its tensors are detached and moved to the
        device of the last tensor leaf of `aux`.
      idxs: optional index applied to the last axis of both attention maps.
      L2: unused.
      thresh: threshold for the hard attention mask.

    Returns:
      First element of the `remd_loss_m1` result.
    """
    dev = torch.utils._pytree.tree_flatten(aux)[0][-1].device
    keys = ['_attn', '_feats']
    obs = {key: next(iter(group.values()))[i] for key, group in zip(keys, aux.values())}
    ref = {key: next(iter(group.values()))[i].detach().to(dev) for key, group in zip(keys, tgt.values())}
    if idxs is not None:
      obs['_attn'] = obs['_attn'][..., idxs]
      ref['_attn'] = ref['_attn'][..., idxs]

    def _filtered_feats(_feats, _attn):
      # assumes _feats is (n, c, h, w) and _attn is (n, h, w, k) — TODO confirm.
      feats_nhwc = _feats.permute(0, 2, 3, 1)
      # Binarise at the attention resolution, then resize; the resized mask is
      # only tested for > 0 below.
      hard = SelfGuidanceEdits._attn_diff_norm(_attn.detach(), hard=True, thresh=thresh)
      mask = TF.resize(hard.permute(0, 3, 1, 2), feats_nhwc.shape[1], antialias=True).permute(0, 2, 3, 1)

      flat_mask = mask.flatten(1, 2)
      flat_feats = feats_nhwc.flatten(1, 2)

      spectrum = torch.fft.fft(flat_feats, dim=-2)
      _, length, _ = flat_feats.shape
      # NOTE(review): the spectrum is not fft-shifted, so indices from
      # length // 2 on are one half of the frequency range, not "the high
      # band" — confirm this is the intended low-pass.
      spectrum[:, length//2:, :] = 0
      filtered = torch.fft.ifftn(spectrum, dim=-2).real

      return filtered[(flat_mask > 0).squeeze(-1)]

    obs_feats = _filtered_feats(**obs)
    tgt_feats = _filtered_feats(**ref)

    return remd_loss_m1(tgt_feats, obs_feats, h=None, cos_d=True,
                        splits=[obs_feats.shape[-1]], return_mat=False)[0]
  
  # TODO
  @staticmethod
  def remd1(aux,i, tgt, idxs=None, L2=False, thresh=0.5):
    """`remd_loss_m1` between attention-masked observation and target features.

    Here the attention map is binarised first and resized to the feature
    resolution afterwards.

    Args:
      aux: nested mapping; its first two entries are read as the attention
        group and the feature group, and the first value of each group is
        indexed with `i`.
      i: index into each group's first value.
      tgt: same structure as `aux`; its tensors are detached and moved to the
        device of the last tensor leaf of `aux`.
      idxs: optional index applied to the last axis of both attention maps.
      L2: unused.
      thresh: threshold for the hard attention mask.

    Returns:
      First element of the `remd_loss_m1` result.
    """
    dev = torch.utils._pytree.tree_flatten(aux)[0][-1].device
    keys = ['_attn', '_feats']
    obs = {key: next(iter(group.values()))[i] for key, group in zip(keys, aux.values())}
    ref = {key: next(iter(group.values()))[i].detach().to(dev) for key, group in zip(keys, tgt.values())}
    if idxs is not None:
      obs['_attn'] = obs['_attn'][..., idxs]
      ref['_attn'] = ref['_attn'][..., idxs]

    def _masked_feats(_feats, _attn):
      # assumes _feats is (n, c, h, w) and _attn is (n, h, w, k) — TODO confirm.
      feats_nhwc = _feats.permute(0, 2, 3, 1)
      # Binarise at the attention resolution, then resize; the resized mask is
      # only tested for > 0 below.
      hard = SelfGuidanceEdits._attn_diff_norm(_attn.detach(), hard=True, thresh=thresh)
      mask = TF.resize(hard.permute(0, 3, 1, 2), feats_nhwc.shape[1], antialias=True).permute(0, 2, 3, 1)
      keep = (mask.flatten(1, 2) > 0).squeeze(-1)
      return feats_nhwc.flatten(1, 2)[keep]

    obs_feats = _masked_feats(**obs)
    tgt_feats = _masked_feats(**ref)

    return remd_loss_m1(tgt_feats, obs_feats, h=None, cos_d=True,
                        splits=[obs_feats.shape[-1]], return_mat=False)[0]

  @staticmethod
  def silhouette(attn, i, tgt, idxs=None, rot=0., sy=1., sx=1., dy=0., dx=0., thresh=True, rsz=None, L2=False):
    attn = attn[i]
    tgt_attn = tgt[i].to(attn.device)


    if idxs is not None:
      attn = attn[...,idxs]
      tgt_attn = tgt_attn[...,idxs]
    if rsz:
      attn = TF.resize(attn.permute(0,3,1,2), rsz, antialias=True).permute(0,2,3,1)
      tgt_attn = TF.resize(tgt_attn.permute(0,3,1,2), rsz, antialias=True).permute(0,2,3,1)
    if thresh:
      attn = SelfGuidanceEdits._attn_diff_norm(attn)
      tgt_attn = SelfGuidanceEdits._attn_diff_norm(tgt_attn, hard=True)
    transform = rot != 0 or any(_!=1. for _ in [sy,sx,dy,dx])
    if transform:
      ns,hs,ws,ks=tgt_attn.shape
      dev=attn.device
      n,h,w,k=torch.meshgrid(torch.arange(ns),torch.arange(ws),
                             torch.arange(hs), torch.arange(ks),indexing='ij')
      n,h,w,k=n.to(dev),h.to(dev),w.to(dev),k.to(dev)
      # centroid
      c = SelfGuidanceEdits._centroid(attn)
      ch = c[...,1][:,None,None]*hs
      cw = c[...,0][:,None,None]*ws
      # object centric coord system
      h = h - ch
      w = w - cw
      # rotate
      angle_deg_cw = rot
      th = angle_deg_cw * math.pi / 180
      wh = torch.stack((w,h), -1)[...,None]
      R = torch.tensor([[math.cos(th), math.sin(th)],[math.sin(-th), math.cos(th)]]).to(dev)
      wh = (R@wh)[...,0]
      w = wh[...,0]
      h = wh[...,1]
      # resize
      h = h/sy
      w = w/sx
      # shift
      y_shift=dy*hs*sy
      x_shift=dx*ws*sx
      h=h-y_shift
      w=w-x_shift
      h = h + ch
      w = w + cw

      h_normalized = (2 * h / (hs - 1)) - 1
      w_normalized = (2 * w / (ws - 1)) - 1
      coords = torch.stack((w_normalized, h_normalized), dim=-1)
      coords_unnorm = torch.stack((w, h), dim=-1)

      coords = coords[:, :, :, 0, :]
      coords_unnorm = coords_unnorm[:, :, :, 0, :]

      # Collapse the batch_size, num_tokens dimension and set num_channels=1 for grid sampling
      tgt_attn = einops.rearrange(tgt_attn, 'n h w k -> n k h w')
      tgt_attn = torch.nn.functional.grid_sample(tgt_attn.float(), coords, mode='bilinear', align_corners=False)
      tgt_attn = einops.rearrange(tgt_attn, 'n k h w -> n h w k')
    if L2: return (0.5*(attn-tgt_attn)**2).mean()
    return (attn-tgt_attn).abs().mean()

  @staticmethod
  def edge(attn, i, tgt, idxs=None, L2=False):
    """Distance between the Sobel edge maps of an attention map and a target.

    Args:
      attn: indexable of observed attention maps; `attn[i]` is expected to be
        (1, H, W, 1) after the optional `idxs` selection.
      i: index of the map to compare.
      tgt: indexable of target maps, or None to compare `attn[i]` with itself.
      idxs: optional index applied to the last axis of both maps.
      L2: return the mean of 0.5 * diff**2 instead of the mean absolute diff.

    Returns:
      Scalar tensor.
    """
    attn = attn[i]
    # Fall back to the observed map when no target is given.  The previous
    # `is not None` check ran after `tgt[i].to(...)` and could never trigger.
    tgt_attn = tgt[i].to(attn.device) if tgt is not None else attn

    if idxs is not None:
      attn = attn[...,idxs] # (1, SG_RES, SG_RES, 1)
      tgt_attn = tgt_attn[...,idxs] # (1, SG_RES, SG_RES, 1)

    # calculate_sobel takes a 2-D map, so drop the batch and token axes.
    attn_sobel = calculate_sobel(attn.squeeze(0).squeeze(-1))
    tgt_attn_sobel = calculate_sobel(tgt_attn.squeeze(0).squeeze(-1))

    attn_sobel = SelfGuidanceEdits._attn_diff_norm(attn_sobel)
    tgt_attn_sobel = SelfGuidanceEdits._attn_diff_norm(tgt_attn_sobel)

    if L2: return (0.5*(attn_sobel - tgt_attn_sobel)**2).mean()
    return (attn_sobel-tgt_attn_sobel).abs().mean()
  
  @staticmethod
  def shape(attn, i, tgt, idxs=None, norm=2, smth=False):
    attn = attn[i]
    tgt_attn = tgt[-1].to(attn.device)

    if idxs is not None:
      attn = attn[...,idxs]
      tgt_attn = tgt_attn[...,idxs]

    if smth:
      attn = smth_3(attn.permute(0,3,1,2)).permute(0,2,3,1)
      tgt_attn = smth_3(tgt_attn.permute(0,3,1,2)).permute(0,2,3,1)

    attn = SelfGuidanceEdits._attn_diff_norm(attn)
    tgt_attn = SelfGuidanceEdits._attn_diff_norm(tgt_attn)

    return (1/norm*((attn-tgt_attn))**norm).abs().mean()

  @staticmethod
  def shape_frobenius(attn, i, tgt, idxs=None, norm=2, smth=False):
    attn = attn[i]
    tgt_attn = tgt[-1].to(attn.device)

    if idxs is not None:
      attn = attn[...,idxs]
      tgt_attn = tgt_attn[...,idxs]

    if smth:
      attn = smth_3(attn.permute(0,3,1,2)).permute(0,2,3,1)
      tgt_attn = smth_3(tgt_attn.permute(0,3,1,2)).permute(0,2,3,1)

    attn = SelfGuidanceEdits._attn_diff_norm(attn)
    tgt_attn = SelfGuidanceEdits._attn_diff_norm(tgt_attn)

    return torch.sqrt(((attn-tgt_attn)**2).mean())

  @staticmethod
  def shape_focal(attn, i, tgt, idxs=None, gamma=3, smth=False):
    attn = attn[i]
    tgt_attn = tgt[-1].to(attn.device)

    if idxs is not None:
      attn = attn[...,idxs]
      tgt_attn = tgt_attn[...,idxs]

    if smth:
      attn = smth_3(attn.permute(0,3,1,2)).permute(0,2,3,1)
      tgt_attn = smth_3(tgt_attn.permute(0,3,1,2)).permute(0,2,3,1)

    attn = SelfGuidanceEdits._attn_diff_norm(attn)
    tgt_attn = SelfGuidanceEdits._attn_diff_norm(tgt_attn)

    logits = (attn - tgt_attn).abs()
    return -((logits**gamma)*torch.log((1-logits)+1e-10)).mean()

  @staticmethod
  def hard_shape(attn, i, tgt, idxs=None, norm=2, focal=False, gamma=3, smth=False):
    attn = attn[i]
    tgt_attn = tgt[-1].to(attn.device)

    if idxs is not None:
      attn = attn[...,idxs]
      tgt_attn = tgt_attn[...,idxs]

    if smth:
      attn = smth_3(attn.permute(0,3,1,2)).permute(0,2,3,1)
      tgt_attn = smth_3(tgt_attn.permute(0,3,1,2)).permute(0,2,3,1)

    attn = SelfGuidanceEdits._attn_diff_norm(attn)
    tgt_attn = SelfGuidanceEdits._attn_diff_norm(tgt_attn, hard=True, thresh=0.3)

    logits = (attn - tgt_attn).abs()
    if focal: return -((logits**gamma)*torch.log((1-logits)+1e-10)).mean()
    return (1/norm*((attn-tgt_attn))**norm).abs().mean()

  @staticmethod
  def centroid(attn, i, tgt=None, shift=(0., 0.), relative=False, idxs=None, L2=False):
    attn = attn[i]
    tgt_attn = tgt[i].to(attn.device) if tgt is not None else None

    if relative: assert tgt_attn is not None
    tgt_attn = tgt_attn if tgt_attn is not None else attn
    if idxs is not None:
      attn = attn[...,idxs]
      tgt_attn = tgt_attn[...,idxs]
    shift = torch.tensor(shift).to(attn.device)

    obs_centroid = SelfGuidanceEdits._centroid(attn)
    tgt_centroid = shift.reshape((1,) * (obs_centroid.ndim - shift.ndim) + shift.shape)
    if relative: tgt_centroid = SelfGuidanceEdits._centroid(tgt_attn) + tgt_centroid
    if L2: return (0.5*(obs_centroid - tgt_centroid)**2).mean()
    return (obs_centroid-tgt_centroid).abs().mean()

  @staticmethod
  def size(attn, i, tgt=None, relative=False, shift=(0.,), thresh=True, idxs=None, L2=False):
    attn = attn[i]
    tgt_attn = tgt[i].to(attn.device) if tgt is not None else None

    if relative: assert tgt_attn is not None
    tgt_attn = tgt_attn if tgt_attn is not None else attn
    if idxs is not None:
      attn = attn[...,idxs]
      tgt_attn = tgt_attn[...,idxs]
    shift = torch.tensor(shift).to(attn.device)
    if thresh:
      def _size(report_attn):
        attn_norm = SelfGuidanceEdits._attn_diff_norm(report_attn)
        return attn_norm.mean((-2, -3))[..., None]
    else:
      def _size(attn):
        return attn.mean((-2, -3))[..., None]

    size_obs = _size(attn)
    size_tgt = shift.reshape((1,) * (size_obs.ndim - shift.ndim) + shift.shape)
    if relative: size_tgt = _size(tgt_attn) + size_tgt
    if L2: return (0.5*(size_obs - size_tgt)**2).mean()
    return (size_obs-size_tgt).abs().mean()
